# Copyright (c) Microsoft. All rights reserved.

"""Base class for group chat orchestrators that manages conversation flow and participant selection."""

import inspect
import logging
import sys
from abc import ABC, abstractmethod
from collections.abc import Awaitable, Callable, Sequence
from typing import Any

from .._types import ChatMessage
from ._executor import Executor
from ._orchestrator_helpers import ParticipantRegistry
from ._workflow_context import WorkflowContext

if sys.version_info >= (3, 12):
    from typing import override
else:
    from typing_extensions import override


logger = logging.getLogger(__name__)


class BaseGroupChatOrchestrator(Executor, ABC):
    """Abstract base class for group chat orchestrators.

    Provides shared functionality for participant registration, routing,
    conversation-state tracking, termination and round limit checking, and
    checkpoint persistence that is common across all group chat patterns.

    Subclasses must implement ``_get_author_name`` and their pattern-specific
    orchestration logic while inheriting the common participant management
    infrastructure.
    """

    def __init__(self, executor_id: str) -> None:
        """Initialize base orchestrator.

        Args:
            executor_id: Unique identifier for this orchestrator executor
        """
        super().__init__(executor_id)
        # Tracks, per participant name, its entry executor ID and whether it is an agent.
        self._registry = ParticipantRegistry()
        # Shared conversation state management
        self._conversation: list[ChatMessage] = []
        self._round_index: int = 0
        # Optional limits. Nothing in this class assigns them beyond these ``None``
        # defaults; presumably subclasses/builders set them. ``None`` disables the check.
        self._max_rounds: int | None = None
        self._termination_condition: Callable[[list[ChatMessage]], bool | Awaitable[bool]] | None = None

    def register_participant_entry(self, name: str, *, entry_id: str, is_agent: bool) -> None:
        """Record routing details for a participant's entry executor.

        This method provides a unified interface for registering participants
        across all orchestrator patterns, whether they are agents or custom executors.

        Args:
            name: Participant name (used for selection and tracking)
            entry_id: Executor ID for this participant's entry point
            is_agent: Whether this is an AgentExecutor (True) or custom Executor (False)
        """
        self._registry.register(name, entry_id=entry_id, is_agent=is_agent)

    # Conversation state management (shared across all patterns)

    def _append_messages(self, messages: Sequence[ChatMessage]) -> None:
        """Append messages to the conversation history.

        Args:
            messages: Messages to append
        """
        self._conversation.extend(messages)

    def _get_conversation(self) -> list[ChatMessage]:
        """Get a shallow copy of the current conversation.

        Returns:
            New list containing the same message objects, so callers can modify
            the list without affecting the orchestrator's own history
        """
        return list(self._conversation)

    def _clear_conversation(self) -> None:
        """Clear the conversation history."""
        self._conversation.clear()

    def _increment_round(self) -> None:
        """Increment the round counter."""
        self._round_index += 1

    async def _check_termination(self) -> bool:
        """Check if conversation should terminate based on termination condition.

        Supports both synchronous and asynchronous termination conditions.

        Returns:
            True if termination condition met, False otherwise (including when
            no termination condition is configured)
        """
        if self._termination_condition is None:
            return False

        # Hand the condition a copy so it cannot alter the orchestrator's own list.
        result = self._termination_condition(self._get_conversation())
        # ``inspect.isawaitable`` is already True for coroutine objects, so this
        # single check covers ``async def`` conditions and any other awaitable.
        if inspect.isawaitable(result):
            result = await result
        return bool(result)

    @abstractmethod
    def _get_author_name(self) -> str:
        """Get the author name for orchestrator-generated messages.

        Subclasses must implement this to provide a stable author name
        for completion messages and other orchestrator-generated content.

        Returns:
            Author name to use for messages generated by this orchestrator
        """
        ...

    def _create_completion_message(
        self,
        text: str | None = None,
        reason: str = "completed",
    ) -> ChatMessage:
        """Create a standardized completion message.

        Args:
            text: Optional message text. When None or empty, the text is
                auto-generated as ``"Conversation {reason}."``.
            reason: Completion reason used only for the auto-generated text

        Returns:
            Assistant-role ChatMessage authored by ``_get_author_name()``
        """
        from .._types import Role

        # ``or`` (not an ``is None`` check) so an empty string also falls back to the default.
        message_text = text or f"Conversation {reason}."
        return ChatMessage(
            role=Role.ASSISTANT,
            text=message_text,
            author_name=self._get_author_name(),
        )

    # Participant routing (shared across all patterns)

    async def _route_to_participant(
        self,
        participant_name: str,
        conversation: list[ChatMessage],
        ctx: WorkflowContext[Any, Any],
        *,
        instruction: str | None = None,
        task: ChatMessage | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> None:
        """Route a conversation to a participant.

        This method handles the dual envelope pattern:
        - AgentExecutors receive AgentExecutorRequest (messages only)
        - Custom executors receive GroupChatRequestMessage (full context)

        Args:
            participant_name: Name of the participant to route to
            conversation: Conversation history to send
            ctx: Workflow context for message routing
            instruction: Optional instruction from manager/orchestrator
                (sent as ``""`` to custom executors when None; unused for agents)
            task: Optional task context (custom executors only)
            metadata: Optional metadata dict (custom executors only)

        Raises:
            ValueError: If participant is not registered
        """
        from ._agent_executor import AgentExecutorRequest
        from ._orchestrator_helpers import prepare_participant_request

        entry_id = self._registry.get_entry_id(participant_name)
        if entry_id is None:
            raise ValueError(f"No registered entry executor for participant '{participant_name}'.")

        if self._registry.is_agent(participant_name):
            # AgentExecutors receive simple message list and are asked to respond.
            await ctx.send_message(
                AgentExecutorRequest(messages=conversation, should_respond=True),
                target_id=entry_id,
            )
        else:
            # Custom executors receive full context envelope
            request = prepare_participant_request(
                participant_name=participant_name,
                conversation=conversation,
                instruction=instruction or "",
                task=task,
                metadata=metadata,
            )
            await ctx.send_message(request, target_id=entry_id)

    # Round limit enforcement (shared across all patterns)

    def _check_round_limit(self) -> bool:
        """Check if round limit has been reached.

        Uses instance variables _round_index and _max_rounds. Logs a warning
        each time it is called while the limit is reached.

        Returns:
            True if limit reached, False otherwise (always False when
            _max_rounds is None)
        """
        if self._max_rounds is None:
            return False

        if self._round_index >= self._max_rounds:
            logger.warning(
                "%s reached max_rounds=%s; forcing completion.",
                self.__class__.__name__,
                self._max_rounds,
            )
            return True

        return False

    # State persistence (shared across all patterns)

    @override
    async def on_checkpoint_save(self) -> dict[str, Any]:
        """Capture current orchestrator state for checkpointing.

        Default implementation uses OrchestrationState to serialize common state
        (conversation and round index). Subclasses can override this method or
        _snapshot_pattern_metadata() to add pattern-specific data.

        Returns:
            Serialized state dict
        """
        from ._orchestration_state import OrchestrationState

        state = OrchestrationState(
            conversation=self._get_conversation(),
            round_index=self._round_index,
            metadata=self._snapshot_pattern_metadata(),
        )
        return state.to_dict()

    def _snapshot_pattern_metadata(self) -> dict[str, Any]:
        """Serialize pattern-specific state.

        Override this method to add pattern-specific checkpoint data.

        Returns:
            Dict with pattern-specific state (empty by default)
        """
        return {}

    @override
    async def on_checkpoint_restore(self, state: dict[str, Any]) -> None:
        """Restore orchestrator state from checkpoint.

        Default implementation uses OrchestrationState to deserialize common state
        (conversation and round index). Subclasses can override this method or
        _restore_pattern_metadata() to restore pattern-specific data.

        Args:
            state: Serialized state dict
        """
        from ._orchestration_state import OrchestrationState

        orch_state = OrchestrationState.from_dict(state)
        # Copy so the restored history is not aliased to the deserialized object's list.
        self._conversation = list(orch_state.conversation)
        self._round_index = orch_state.round_index
        self._restore_pattern_metadata(orch_state.metadata)

    def _restore_pattern_metadata(self, metadata: dict[str, Any]) -> None:
        """Restore pattern-specific state.

        Override this method to restore pattern-specific checkpoint data.
        The default implementation ignores ``metadata``.

        Args:
            metadata: Pattern-specific state dict
        """
